import math
from typing import List, Optional

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM


class AveragePPLDetector:
    """
    Average-PPL detector.

    Perplexity = exp(mean NLL over every predicted token of the prompt), i.e.
    all non-padding positions except the first one (which has no preceding
    context to be predicted from).
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
    ):
        """Load the tokenizer and causal LM for *model_name*.

        Args:
            model_name: model id or local path passed to ``from_pretrained``
                (loaded with ``trust_remote_code=True``).
            device: torch device string; defaults to ``"cuda"`` when
                available, otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # float16 halves accelerator memory, but many CPU kernels (e.g. addmm)
        # have no Half implementation in common torch builds -> float32 on CPU.
        dtype = (
            torch.float32
            if torch.device(self.device).type == "cpu"
            else torch.float16
        )

        # Load tokenizer + model
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_fast=True, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, torch_dtype=dtype
        )

        if self.tokenizer.pad_token_id is None:
            try:
                print("Adding pad token to tokenizer...")
                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                self.model.resize_token_embeddings(len(self.tokenizer))
            except ValueError:
                print("Tokenizer does not support adding special tokens.")
                # NOTE(review): no embedding resize here; assumes
                # '<|endoftext|>' is already in the vocabulary — confirm per model.
                self.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})

        # score()/score_batch() slice each row as ``[:length]``, which is only
        # correct when valid tokens come first. Force right padding even when
        # the tokenizer shipped with its own pad token and padding side.
        self.tokenizer.padding_side = "right"

        self.model.to(self.device).eval()

    # ---------- internal helpers ---------- #

    def _compute_nll_batch(self, texts: List[str]):
        """Compute per-token negative log-likelihoods for a batch of texts.

        Args:
            texts: raw strings; tokenized with padding and truncation.

        Returns:
            loss: ``(B, T-1)`` float32 tensor; ``loss[b, t]`` is the NLL of
                token ``t + 1`` given tokens ``0..t``. Padding positions are 0.
            lengths: per-row number of valid (non-padding) predicted tokens.
                With right padding these are the leading positions of the row.
        """
        enc = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
        ).to(self.device)

        input_ids = enc.input_ids
        attention_mask = enc.attention_mask

        with torch.no_grad():
            # Pass only the inputs a causal LM needs: some tokenizers also emit
            # token_type_ids, which many causal-LM forward() signatures reject.
            logits = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            ).logits  # (B, T, V)

        # Upcast before log-softmax: float16 over a large vocab loses precision
        # and the resulting NLLs can overflow when exponentiated later.
        shift_logits = logits[:, :-1, :].float()
        shift_labels = input_ids[:, 1:].clone()

        # Validity comes from the attention mask, not ``label != pad_token_id``:
        # when the pad token is also a real token (eos / <|endoftext|>), an id
        # comparison would silently drop genuine occurrences from the text.
        mask = attention_mask[:, 1:].bool()
        shift_labels[~mask] = -100  # ignore in CE

        B, Tm1, V = shift_logits.size()
        flat_loss = F.cross_entropy(
            shift_logits.reshape(-1, V),
            shift_labels.reshape(-1),
            reduction='none',
            ignore_index=-100,
        )
        loss = flat_loss.view(B, Tm1)          # (B, T-1)
        lengths = mask.sum(dim=1).tolist()     # number of valid tokens per sample

        return loss, lengths

    def _average_ppl(self, nll_vec: torch.Tensor) -> float:
        """Return exp(mean NLL) over the whole sequence.

        The reduction runs in float64, so half-precision inputs cannot overflow
        (float16 ``exp`` saturates to inf once the mean NLL exceeds ~11.09).
        An empty vector (a text with fewer than two tokens) gives NaN.
        """
        if nll_vec.numel() == 0:
            return float("nan")
        return float(torch.exp(nll_vec.double().mean()).item())

    # ---------- public API ---------- #

    def score(self, text: str) -> float:
        """Return the average perplexity of a single *text*."""
        nll_batch, lengths = self._compute_nll_batch([text])
        return self._average_ppl(nll_batch[0, : lengths[0]])

    def score_batch(
        self,
        texts: List[str],
        batch_size: int = 8,
    ) -> List[float]:
        """Return the average perplexity of each text, in input order.

        Args:
            texts: strings to score.
            batch_size: number of texts per forward pass; must be >= 1.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        results: List[float] = []
        for start in range(0, len(texts), batch_size):
            chunk = texts[start : start + batch_size]
            nll_batch, lengths = self._compute_nll_batch(chunk)
            results.extend(
                self._average_ppl(nll_batch[row, :n])
                for row, n in enumerate(lengths)
            )
        return results